package kvraft

import (
	"time"
	// "fmt"
	// "strings"
	"../labgob"
	"../labrpc"
	"log"
	"../raft"
	"sync"
	"sync/atomic"
)

// Debug turns on DPrintf output when set to a value greater than zero.
const Debug = 0

// DPrintf forwards format and a to log.Printf, but only when Debug > 0.
// Its results are always the zero values; they exist so that call sites can
// use it like fmt.Printf.
func DPrintf(format string, a ...interface{}) (n int, err error) {
	if Debug <= 0 {
		return 0, nil
	}
	log.Printf(format, a...)
	return 0, nil
}

// opType identifies which key/value operation an Op carries.
type opType int

// Operation kinds. Op values (including this field) are handed to Raft via
// rf.Start and registered with labgob, so the numeric values should be
// treated as part of the log encoding.
const (
	GET opType = iota
	PUT
	APPEND
)

// Op is the command replicated through Raft for every client request.
// Field names must start with capital letters, otherwise RPC/labgob
// encoding will break.
type Op struct {
	Key, Value string
	OpType     opType // GET, PUT or APPEND
	ClientId   int64  // identifies the issuing client
	CommandId  int64  // per-client request sequence number used for de-duplication
}

// String renders the operation for debug logging.
func (op Op) String() string {
	var label string
	switch op.OpType {
	case GET:
		return "{Get " + op.Key + "}"
	case PUT:
		label = "PUT"
	case APPEND:
		label = "Append"
	default:
		return "Invalid Op"
	}
	return "{" + label + " [" + op.Key + "] " + op.Value + "}"
}

// CommandResponse is the result the apply loop hands to the RPC handler
// waiting on a particular log index.
type CommandResponse struct {
	Err   Err
	Value string
	// CommandId and ClientId identify the Op that was actually applied at the
	// index, letting the handler notice that a different command took its slot.
	CommandId, ClientId int64
	// CommandTerm is the term reported in the ApplyMsg for the applied entry.
	CommandTerm int
}

// KVServer is one replica of the key/value service; its state machine is
// driven by the entries its Raft peer delivers on applyCh.
type KVServer struct {
	mu      sync.Mutex // guards storage, acked and notifyChans
	me      int
	rf      *raft.Raft
	applyCh chan raft.ApplyMsg
	dead    int32 // set by Kill()

	maxraftstate int // snapshot if log grows this big (-1 disables snapshots)

	// storage is the key/value state machine.
	storage map[string]string
	// acked maps a client ID to the highest CommandId of a PUT/APPEND that
	// has been applied for that client.
	acked map[int64]int64
	// notifyChans maps a log index to the channel an RPC handler waits on.
	notifyChans map[int]chan *CommandResponse
}

// clearResponse removes every cached response whose command ID is lower than the new one.
// The map does not need to be passed by pointer: Go maps are reference types.
// func clearResponse(response ResponseMemo, clientId int64, commandId int64) {
// 	clientReses := response[clientId]
// 	for cid, res := range clientReses {
// 		if res.CommandID < commandId {
// 			delete(clientReses, cid)
// 		}
// 	}
// }

// getChannel returns the notification channel registered for logindex, or
// nil when no RPC handler is waiting on that index.
func (kv *KVServer) getChannel(logindex int) chan *CommandResponse {
	kv.mu.Lock()
	defer kv.mu.Unlock()
	// A missing key yields the zero value of the channel type, which is nil.
	return kv.notifyChans[logindex]
}

// deleteChannel closes and unregisters the notification channel for
// logindex. It does nothing when no channel is registered: closing the nil
// value returned for a missing key would panic, and that can happen when two
// handlers ended up sharing the same log index and one already cleaned up.
func (kv *KVServer) deleteChannel(logindex int) {
	kv.mu.Lock()
	defer kv.mu.Unlock()
	ch, ok := kv.notifyChans[logindex]
	if !ok {
		return
	}
	close(ch)
	delete(kv.notifyChans, logindex)
}

// Get is the RPC handler for reads. The read is replicated through Raft
// like any other operation so the returned value is linearizable.
func (kv *KVServer) Get(args *GetArgs, reply *GetReply) {
	err, val := kv.requestHandler(Op{
		OpType:    GET,
		Key:       args.Key,
		ClientId:  args.ClientId,
		CommandId: args.CommandId,
	})
	reply.Err = err
	reply.Value = val
}

// PutAppend is the RPC handler for writes. args.Op must be "Put" or
// "Append"; anything else is rejected without touching Raft.
func (kv *KVServer) PutAppend(args *PutAppendArgs, reply *PutAppendReply) {
	var kind opType
	switch args.Op {
	case "Put":
		kind = PUT
	case "Append":
		kind = APPEND
	default:
		reply.Err = "Invalid Op"
		return
	}

	reply.Err, _ = kv.requestHandler(Op{
		Key:       args.Key,
		Value:     args.Value,
		OpType:    kind,
		ClientId:  args.ClientId,
		CommandId: args.CommandId,
	})
}

// requestHandler submits req to Raft and waits (up to one second) for the
// apply loop to report the outcome at the log index Raft assigned.
//
// If the client's request was already applied (acked CommandId >= req's),
// it answers "" immediately without consulting Raft; for GET it returns the
// local value. Otherwise it returns ErrWrongLeader when this peer is not the
// leader, "Lost leadership" when a different command (or a different term)
// was applied at the index, and "Timeout" when nothing arrived in time.
func (kv *KVServer) requestHandler(req Op) (replyErr Err, value string) {
	kv.mu.Lock()
	// Each RPC implies the client has seen the reply for its previous RPC.
	if ackedCommandId, ok := kv.acked[req.ClientId]; ok && ackedCommandId >= req.CommandId {
		DPrintf("Server %d: find duplicate request", kv.me)
		if req.OpType == GET {
			value = kv.storage[req.Key]
		}
		kv.mu.Unlock()
		return "", value
	}
	kv.mu.Unlock()

	// Start is called without kv.mu held so Raft never waits on our lock.
	index, term, isLeader := kv.rf.Start(req)
	if !isLeader {
		return ErrWrongLeader, ""
	}
	DPrintf("kvserver leader %d: Op Started.", kv.me)

	// Capacity 1: the apply loop's single send for this index can always
	// complete, even after we have stopped waiting, so it never blocks.
	ch := make(chan *CommandResponse, 1)
	kv.mu.Lock()
	kv.notifyChans[index] = ch
	kv.mu.Unlock()

	// Unregister only our own channel: after a leadership change another
	// handler may have registered a newer channel at the same index. The
	// channel is deliberately not closed, so no sender can ever panic on it.
	defer func() {
		kv.mu.Lock()
		if kv.notifyChans[index] == ch {
			delete(kv.notifyChans, index)
		}
		kv.mu.Unlock()
	}()

	select {
	case res, ok := <-ch:
		// !ok / nil guard against the channel having been closed elsewhere.
		if !ok || res == nil || res.ClientId != req.ClientId || res.CommandId != req.CommandId || res.CommandTerm != term {
			return "Lost leadership", ""
		}
		return res.Err, res.Value
	case <-time.After(time.Second):
		return "Timeout", ""
	}
}

// Kill is called by the tester when this KVServer instance is no longer
// needed. It marks the server dead (atomically, so no lock is needed) and
// shuts down the underlying Raft peer. Long-running loops poll killed() to
// notice the shutdown; it can also be used to silence debug output from a
// Kill()ed instance.
func (kv *KVServer) Kill() {
	atomic.StoreInt32(&kv.dead, 1)
	kv.rf.Kill()
}

// killed reports whether Kill has been called on this server.
func (kv *KVServer) killed() bool {
	return atomic.LoadInt32(&kv.dead) == 1
}

// applyMsgReciever drains kv.applyCh until the server is killed. Command
// messages carrying an Op are applied to the state machine and reported to
// the RPC handler (if any) waiting on that log index; every other message is
// treated as a snapshot and replaces the state machine wholesale.
func (kv *KVServer) applyMsgReciever() {
	for !kv.killed() {
		applyMsg := <-kv.applyCh
		if !applyMsg.CommandValid {
			// Install the snapshot delivered by Raft.
			DPrintf("kvserver %d: state machine updated", kv.me)
			kv.mu.Lock()
			snapshot := raft.DecodeSnapshot(applyMsg.SnapshotData)
			kv.storage = snapshot.Storage
			kv.acked = snapshot.Acked
			kv.mu.Unlock()
			continue
		}

		command, ok := applyMsg.Command.(Op)
		if !ok {
			continue
		}

		kv.mu.Lock()
		cmdRes := kv.applyCommandLocked(command)
		cmdRes.CommandTerm = applyMsg.CommandTerm
		// Notify while still holding kv.mu: any code that closes or removes a
		// channel does so under the same lock, so we can never send on a
		// closed channel. The send is non-blocking so a handler that has
		// already given up can never stall the apply loop.
		if waitChan, ok := kv.notifyChans[applyMsg.CommandIndex]; ok {
			select {
			case waitChan <- cmdRes:
			default:
			}
		}
		kv.mu.Unlock()
	}
}

// applyCommandLocked executes command against the state machine and returns
// the response for a waiting handler (CommandTerm is left for the caller).
// PUT/APPEND run at most once per (client, CommandId); GET always reads the
// current value. The caller must hold kv.mu.
func (kv *KVServer) applyCommandLocked(command Op) *CommandResponse {
	cmdRes := &CommandResponse{
		Err:       "",
		ClientId:  command.ClientId,
		CommandId: command.CommandId,
	}

	ackedCmdId, seen := kv.acked[command.ClientId]
	switch {
	case command.OpType == GET:
		cmdRes.Value = kv.storage[command.Key]
	case !seen || ackedCmdId < command.CommandId:
		kv.acked[command.ClientId] = command.CommandId
		if command.OpType == PUT {
			kv.storage[command.Key] = command.Value
		} else if command.OpType == APPEND {
			kv.storage[command.Key] += command.Value
		}
	default:
		DPrintf("kvserver %d: find duplicate applyMsg, Command: %v", kv.me, command)
	}
	return cmdRes
}


// snapshotting checks once per second whether Raft's persisted state has
// outgrown maxraftstate and, if so, asks Raft to build a snapshot of the
// current key/value state. It returns immediately when snapshots are
// disabled (maxraftstate == -1) and exits once the server is killed.
func (kv *KVServer) snapshotting() {
	if kv.maxraftstate == -1 {
		return
	}
	for ; !kv.killed(); time.Sleep(time.Second) {
		if !kv.rf.ExceededMaxRaftState(kv.maxraftstate) {
			continue
		}
		// The apply loop mutates storage/acked under kv.mu, so reading them
		// unlocked is a data race (and can crash the runtime). Copy them
		// under the lock and hand Raft the copies, without holding kv.mu
		// across the call into Raft.
		// NOTE(review): Raft presumably tags the snapshot with its own
		// applied index; verify that index matches the state copied here.
		kv.mu.Lock()
		storage := make(map[string]string, len(kv.storage))
		for k, v := range kv.storage {
			storage[k] = v
		}
		acked := make(map[int64]int64, len(kv.acked))
		for id, cmd := range kv.acked {
			acked[id] = cmd
		}
		kv.mu.Unlock()

		kv.rf.BuildInstallSnapshot(storage, acked)
	}
}

// StartKVServer creates one replica of the fault-tolerant key/value service.
//
// servers[] contains the ports of the set of servers that will cooperate
// via Raft to form the service; me is the index of the current server in
// servers[]. The k/v server should store snapshots through the underlying
// Raft implementation, which should call persister.SaveStateAndSnapshot()
// to atomically save the Raft state along with the snapshot. The k/v server
// should snapshot when Raft's saved state exceeds maxraftstate bytes, in
// order to allow Raft to garbage-collect its log; if maxraftstate is -1,
// no snapshots are needed. StartKVServer() must return quickly, so all
// long-running work happens in goroutines started here.
func StartKVServer(servers []*labrpc.ClientEnd, me int, persister *raft.Persister, maxraftstate int) *KVServer {
	// Op values travel inside Raft log entries, so labgob must know the type.
	labgob.Register(Op{})

	kv := &KVServer{
		me:           me,
		maxraftstate: maxraftstate,
		applyCh:      make(chan raft.ApplyMsg),
	}
	kv.rf = raft.Make(servers, me, persister, kv.applyCh)

	// Rebuild the state machine from the last snapshot, if there is one.
	if persister.SnapshotSize() <= 0 {
		kv.storage = make(map[string]string)
		kv.acked = make(map[int64]int64)
	} else {
		snap := raft.DecodeSnapshot(persister.ReadSnapshot())
		kv.storage, kv.acked = snap.Storage, snap.Acked
	}
	kv.notifyChans = make(map[int]chan *CommandResponse)

	go kv.applyMsgReciever()
	go kv.snapshotting() // lab 3B

	return kv
}
